from pyspark.sql import SparkSession
from pyecharts.charts import Bar, Line
import pyecharts.options as opts

if __name__ == '__main__':
    # Render a combined bar+line chart from 51job postings on HDFS:
    # bars = number of postings per (city, degree), line = mean yearly salary
    # (in 10k CNY) per (city, degree), for 12 selected cities.
    spark = SparkSession.builder.appName("Jobs5").master("yarn").getOrCreate()
    try:
        df = spark.read.json("hdfs://node1:8020/input/datasets/51jobs.json")
        # NOTE(review): table name starts with a digit, so it is backtick-quoted
        # in SQL below to avoid relying on Spark's lenient legacy lexer.
        df.createOrReplaceTempView("51Jobs")

        # One query produces both metrics (the original ran two near-identical
        # three-way joins).  The city x degree grid is built with an explicit
        # CROSS JOIN — the original's condition-less LEFT JOIN between the two
        # dimension subqueries was an implicit cartesian product, which older
        # Spark versions reject outright.
        # count(c.yearSalary) == sum(case when yearSalary is null then 0 else 1 end):
        # unmatched grid cells (null) contribute 0.
        # The mean deliberately keeps the original semantics of treating a null
        # salary as 0 rather than excluding it.
        stats = spark.sql(
            "select a.city, b.degreeString"
            "      ,count(c.yearSalary) as count"
            "      ,mean(case when c.yearSalary is null then 0 else c.yearSalary end) as mean"
            "  from (select distinct city from `51Jobs`) a"
            " cross join (select distinct degreeString from `51Jobs`) b"
            "  left join `51Jobs` c"
            "    on a.city = c.city"
            "   and b.degreeString = c.degreeString"
            " where a.city in ('北京','上海','广州','深圳','南京','杭州','武汉','成都','苏州','天津','无锡','西安')"
            " group by a.city, b.degreeString"
        )

        # Single driver round-trip instead of eight collect()s; pivot in Python.
        rows = stats.collect()

        # Python's sorted() orders strings by codepoint, matching Spark's
        # orderBy("city"), so the axis order is unchanged from the original.
        xaxis_data = sorted({r["city"] for r in rows})

        degrees = ["大专", "本科", "硕士", "博士"]
        count_by = {(r["city"], r["degreeString"]): r["count"] for r in rows}
        mean_by = {(r["city"], r["degreeString"]): r["mean"] for r in rows}
        # Per-degree series aligned with xaxis_data; the cross join guarantees
        # every (city, degree) cell exists, but .get(...) keeps this robust.
        bar_series = {d: [count_by.get((c, d), 0) for c in xaxis_data] for d in degrees}
        line_series = {d: [mean_by.get((c, d), 0) for c in xaxis_data] for d in degrees}

        # Bar chart on the primary y-axis (job counts), with a secondary axis
        # reserved for the salary line overlaid below.
        bar = Bar().add_xaxis(xaxis_data)
        for d in degrees:
            bar.add_yaxis(d, bar_series[d])
        bar.extend_axis(yaxis=opts.AxisOpts(name="年薪（万）")) \
            .set_series_opts(label_opts=opts.LabelOpts(is_show=False)) \
            .set_global_opts(
                xaxis_opts=opts.AxisOpts(
                    # Show every city label, rotated so they do not overlap.
                    axislabel_opts={"interval": "0", "rotate": 45}
                ),
                yaxis_opts=opts.AxisOpts(max_=1000, name="岗位数")
            )

        # Salary lines bound to the secondary y-axis (yaxis_index=1).
        line = Line().add_xaxis(xaxis_data)
        for d in degrees:
            line.add_yaxis(d, line_series[d], yaxis_index=1)
        line.set_series_opts(label_opts=opts.LabelOpts(is_show=False))

        bar.overlap(line)
        bar.render()
    finally:
        # Always release the YARN application, even if the job fails.
        spark.stop()
